package LiuYun

import spinal.core._
import spinal.lib._

import spinal.lib.bus.amba4.axi._

/** AXI3 write-data payload: the Axi4W fields plus the 4-bit write ID (WID) that AXI3 carries on W. */
class LabAxi3W(config:Axi4Config) extends Axi4W(config){
    // AXI3, unlike AXI4, keeps a transaction ID on the W channel.
    val id = UInt(4 bits)
}

/** Factory so the payload can be built as `LabAxi3W(config)`, mirroring SpinalHDL's Axi4 bundles. */
object LabAxi3W{
    def apply(config:Axi4Config): LabAxi3W = new LabAxi3W(config = config)
}
/**
 * AXI3 master/slave interface used by the lab: the five AXI channels as Streams,
 * with the W channel carrying an AXI3 write ID (see LabAxi3W).
 *
 * Bus configuration is fixed: 32-bit address, 32-bit data, 4-bit IDs,
 * no REGION and no QOS signals.
 */
class LabAxi3 extends Bundle with IMasterSlave{
    val config = new Axi4Config(32,32,
            idWidth=4,
            useRegion=false,
            useQos=false)
    val aw = Stream(   Axi4Aw(config))
    val w  = Stream(LabAxi3W (config))
    val b  = Stream(   Axi4B (config))
    val ar = Stream(   Axi4Ar(config))
    val r  = Stream(   Axi4R (config))

    /** The master drives the request channels (AR, AW, W) and receives R and B. */
    override def asMaster():Unit = {
        master(ar,aw,w)
        slave (r,b)
    }

    /**
     * Rewrites the current name of every leaf signal in `channel` into a flat
     * AXI-style port name:
     *  - every "_payload_" is removed,
     *  - "_valid" / "_ready" become "valid" / "ready",
     *  - when `prefix` is non-empty, every occurrence of `prefix` is removed.
     *
     * For example, assuming SpinalHDL's default `<parent>_<channel>_payload_<field>`
     * naming, "io_axi_aw_payload_addr" with prefix "io_axi_" becomes "awaddr".
     * NOTE(review): names are read via getName() at call time, so this
     * presumably must run after the signals have been named — confirm at the call site.
     *
     * @param channel bundle whose flattened leaf signals are renamed
     * @param prefix  text to strip from every name ("" disables stripping)
     */
    def rename(channel:Bundle,prefix:String):Unit = {
        for(element <- channel.flatten){
            def replace(pat:String,res:String):Unit = {
                val prev:String = element.getName()
                val name:String = prev.replace(pat,res)
                element.setName(name)
            }
            replace("_payload_","")
            replace("_valid","valid")
            replace("_ready","ready")
            if(prefix != "") replace(prefix,"")
        }
    }

    /**
     * Applies `rename` to all five channels (aw, w, b, ar, r).
     *
     * @param prefix text to strip from every signal name ("" by default, meaning no stripping)
     */
    def renamed(prefix:String=""):Unit = {
        Seq[Bundle](aw,w,b,ar,r).foreach(rename(_,prefix))
    }
}

/**
 * Arbitrates cache-miss traffic onto a single AXI3 master port.
 *
 * - Reads: data-cache loads (`io.dmiss`, AXI ID 1) and instruction fetches
 *   (`io.imiss`, AXI ID 0) share the AR channel. At most one read burst is in
 *   flight (`ar_sent`). When both request in the same cycle, the load wins.
 * - Writes: only the data cache writes (AXI ID 1). At most one write burst is
 *   in flight (`aw_sent`). Its AW request and up to four W beats are latched
 *   into registers and presented until the slave accepts them.
 * - Responses: R beats are routed to the cache matching their ID. B responses
 *   with ID 1 are forwarded to the data cache. R and B are never back-pressured.
 *
 * `line` on a request selects a 4-beat burst (LEN = 3) versus a single beat
 * (LEN = 0).
 * NOTE(review): `io.dmiss.occupy` is defined outside this file; here it only
 * blocks new instruction fetches.
 */
class LabAxi3Arbiter extends Component{
    val io = new Bundle{
        val imiss:IMiss = slave(new IMiss)
        val dmiss:DMiss = slave(new DMiss)
        val axi  :LabAxi3 = master(new LabAxi3)
    }

    /** Latched AR request: ID, address, SIZE and LEN. */
    class ArEntry extends Bundle{
        val id   = UInt(4  bits)
        val addr = UInt(32 bits)
        val size = UInt(3  bits)
        val len  = UInt(8  bits)
        def assignedWith(id:UInt,addr:UInt,size:UInt,len:UInt):this.type = {
            this.id   := id
            this.addr := addr
            this.size := size
            this.len  := len
            this
        }
    }

    /** Latched AW request: address, SIZE and LEN (the write ID is fixed to 1). */
    class AwEntry extends Bundle{
        val addr = UInt(32 bits)
        val size = UInt( 3 bits)
        val len  = UInt( 8 bits)
        def assignedWith(addr:UInt,size:UInt,len:UInt):this.type = {
            this.addr := addr
            this.size := size
            this.len  := len
            this
        }
    }

    /**
     * Latched write data for one burst: four 32-bit words, one byte strobe
     * shared by every beat, and `num`, the index of the beat being sent.
     */
    class WEntry extends Bundle{
        val datas = Vec(UInt(32 bits),4)
        val strb  = Bits( 4 bits)
        val num   = UInt( 4 bits)
        /** Data for the current beat (only the low two bits of `num` index `datas`). */
        def data:Bits = datas(num(1 downto 0)).asBits
        def assignedWith(datas:Vec[UInt],strb:Bits):this.type = {
            this.datas := datas
            this.strb := strb
            this.num  := U(0)
            this
        }
    }

    // *_send: the request is currently presented on the bus (VALID high).
    // *_sent: a burst was issued and its response has not completed yet.
    val ar_send:Bool = RegInit(False)
    val ar_sent:Bool = RegInit(False)
    val aw_send:Bool = RegInit(False)
    val w_send :Bool = RegInit(False)
    val aw_sent:Bool = RegInit(False)

    // Incoming requests: data-cache load, data-cache store, instruction fetch.
    val want_ld:Bool = io.dmiss.arvalid
    val want_st:Bool = io.dmiss.awvalid
    val want_fe:Bool = io.imiss.arvalid

    // Whether each kind of request may start now.
    // A load waits for any outstanding write. It also needs the AR channel to be
    // idle, which is checked by `when(!ar_sent)` below.
    val recv_ld:Bool =  !aw_sent
    // A store may overlap an outstanding read only if it is a line (burst) store.
    val recv_st:Bool = (!ar_sent || io.dmiss.line)
    // A fetch needs both the read and write sides idle and the data cache not occupied.
    val recv_fe:Bool =  !aw_sent && !ar_sent && !io.dmiss.occupy
    val take_ld:Bool = want_ld && recv_ld
    val take_st:Bool = want_st && recv_st
    val take_fe:Bool = want_fe && recv_fe

    val ar_buf :ArEntry = Reg(new ArEntry)
    val aw_buf :AwEntry = Reg(new AwEntry)
    val w_buf  :WEntry  = Reg(new WEntry)

    // Multi-beat reads use WRAP. AXI only permits WRAP bursts of 2/4/8/16 beats,
    // so single-beat reads (LEN = 0) must use INCR instead.
    io.axi.ar.burst := Mux(ar_buf.len === U(0), B"01", B"10")
    io.axi.ar.lock  := B(0)
    io.axi.ar.cache := B(0)
    io.axi.ar.prot  := B(0)
    // All writes come from the data cache (ID 1) and use INCR bursts.
    io.axi.aw.id    := U(1)
    io.axi.aw.burst := B"01"
    io.axi.aw.lock  := B(0)
    io.axi.aw.cache := B(0)
    io.axi.aw.prot  := B(0)
    io.axi.w.id     := U(1)

    // AXI LEN encodes (beats - 1): a line is 4 beats, a single word is 1 beat.
    val line_len:UInt = U(3, 8 bits)
    val word_len:UInt = U(0, 8 bits)
    def getLen(is_line:Bool):UInt = {
        Mux(is_line,line_len,word_len)
    }

    // Read address channel: latch a new request only when no read is outstanding.
    when(!ar_sent){
        when(take_ld){
            ar_send := True
            ar_sent := True
            ar_buf  := (new ArEntry).assignedWith(
                U(1),
                io.dmiss.araddr,
                io.dmiss.arsize,
                getLen(io.dmiss.line)
            )
        }.elsewhen(take_fe){
            ar_send := True
            ar_sent := True
            // Instruction fetches always use 4-byte transfers (SIZE = 2).
            ar_buf  := (new ArEntry).assignedWith(
                U(0),
                io.imiss.araddr,
                U(2),
                getLen(io.imiss.line)
            )
        }
    }.otherwise{
        // Drop ARVALID once the slave has accepted the address.
        when(io.axi.ar.ready){
            ar_send := False
        }
        // The read completes with its last data beat.
        when(io.axi.r.valid && io.axi.r.ready && io.axi.r.last){
            ar_sent := False
        }
    }

    // Write channels: latch a new request only when no write is outstanding.
    when(!aw_sent){
        when(take_st){
            aw_send := True
            w_send  := True
            aw_sent := True
            aw_buf  := (new AwEntry).assignedWith(
                io.dmiss.awaddr,
                io.dmiss.awsize,
                getLen(io.dmiss.line)
            )
            w_buf   := (new WEntry).assignedWith(
                io.dmiss.wdata,
                io.dmiss.wstrb
            )
        }
    }.otherwise{
        // Drop AWVALID once the slave has accepted the address.
        when(io.axi.aw.ready){
            aw_send := False
        }
        // Each accepted W beat advances to the next word; stop after WLAST.
        when(io.axi.w.ready){
            when(io.axi.w.last){
                w_send := False
            }
            when(w_send){
                w_buf.num := w_buf.num + U(1)
            }
        }
        // The write completes when its B response arrives.
        when(io.axi.b.valid){
            aw_sent := False
        }
    }

    io.axi.ar.payload.assignSomeByName(ar_buf)
    io.axi.aw.payload.assignSomeByName(aw_buf)
    io.axi.w.data := w_buf.data
    io.axi.w.strb := w_buf.strb
    io.axi.w.last := w_buf.num === aw_buf.len

    io.axi.ar.valid := ar_send
    io.axi.aw.valid := aw_send
    io.axi.w.valid  := w_send
    // The caches are assumed to always be able to accept a response, so R and B
    // are never back-pressured.
    io.axi.b.ready := True
    io.axi.r.ready := True

    io.dmiss.arready := !aw_sent && !ar_sent
    // A data-cache load has priority over a fetch (take_ld is checked before
    // take_fe). Report the fetch as accepted only when no load is requested in
    // the same cycle. Otherwise the icache would complete a handshake for a
    // request that was never issued.
    io.imiss.arready := recv_fe && !want_ld
    io.dmiss.rvalid  := io.axi.r.valid && io.axi.r.id === U(1)
    io.imiss.rvalid  := io.axi.r.valid && io.axi.r.id === U(0)
    io.dmiss.rlast   := io.axi.r.last && io.axi.r.id === U(1)
    io.imiss.rlast   := io.axi.r.last && io.axi.r.id === U(0)
    io.imiss.rdata   := io.axi.r.data.asUInt
    io.dmiss.rdata   := io.axi.r.data.asUInt
    io.dmiss.bvalid  := io.axi.b.valid && io.axi.b.id === U(1)
}
